# core/formalization/rl/rl_manager.py
import os
import torch
import json
import numpy as np
import datetime
from tqdm import tqdm
import random
from typing import Dict, List, Any, Optional, Tuple

from core.formalization.action_space import (
    ActionType,
    FormalizationActionSpace,
    get_action_type_by_index,
)
from core.formalization.symbol_manager import SymbolManager
from core.formalization.rl.state import State, STATE_DIM
from core.formalization.rl.rl_agent import PPO
from core.formalization.rl.exp import StepExp
from core.formalization.action.symbolic_abstraction_action import SymbolicAbstractionAction
from core.formalization.action.logical_encoding_action import LogicalEncodingAction
from core.formalization.action.math_repr_action import MathReprAction
from core.formalization.action.domain_spec_action import DomainSpecAction
from core.formalization.action.strategic_decomp_action import StrategicDecompAction
from core.formalization.action.metaphor_action import MetaphorAction
from core.formalization.action.fallback_action import FallbackAction
from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary
import core.agent_prompt as AgentPrompt
from llm.message import (
    Message,
    MessageContent,
    ROLE_SYSTEM,
    ROLE_USER,
    TYPE_SETTING,
    TYPE_CONTENT,
)


class RLManager:
    """Couple the formalization action space with a PPO agent.

    The manager registers the formalization actions, builds (and optionally
    restores) the PPO agent, rolls out episodes in which the selected action
    rewrites the current query, and appends every step to
    ``<cache_dir>/record_history.jsonl``.

    ``config`` keys read directly here (the whole dict is also passed to
    ``PPO``): ``max_steps`` (10), ``batch_size`` (4), ``update_epochs`` (8),
    ``max_failed_attempts`` (defaults to ``max_steps``),
    ``max_consecutive_failures`` (50), ``save_interval`` (10), ``model_name``,
    ``output_dir``, ``cache_dir`` and the exploration-schedule keys read by
    ``_calculate_exp_factor``.
    """

    def __init__(
        self,
        logger: Logger,
        llm: LLMWrapper,
        auxiliary: Auxiliary,
        symbol_manager: SymbolManager,
        config: Optional[Dict] = None,
    ):
        # Avoid a shared mutable ``{}`` default: every manager gets its own dict.
        if config is None:
            config = {}
        self.logger = logger
        self.llm = llm
        self.auxiliary = auxiliary
        self.config = config

        self.symbol_manager = symbol_manager
        self.action_space = FormalizationActionSpace(logger, llm)
        self._init_actions(logger, llm, self.symbol_manager)
        self.agent = self._init_model(logger, llm, auxiliary, config)

        self.max_steps = config.get("max_steps", 10)
        self.batch_size = config.get("batch_size", 4)
        self.update_epochs = config.get("update_epochs", 8)
        # step_count only advances when an action is applied successfully, so
        # failed / raising steps need their own budget or a persistently
        # failing action would spin the rollout loop forever.
        self.max_failed_attempts = config.get("max_failed_attempts", self.max_steps)

    def _init_actions(self, logger, llm, symbol_manager):
        """Register every formalization action with the action space.

        NOTE(review): registration order presumably has to line up with the
        indices used by ``get_action_type_by_index`` -- confirm before
        reordering this list.
        """
        self.action_space.register_actions(
            [
                FallbackAction(logger, llm, symbol_manager),
                SymbolicAbstractionAction(logger, llm, symbol_manager),
                LogicalEncodingAction(logger, llm, symbol_manager),
                MathReprAction(logger, llm, symbol_manager),
                DomainSpecAction(logger, llm, symbol_manager),
                StrategicDecompAction(logger, llm, symbol_manager),
                MetaphorAction(logger, llm, symbol_manager),
            ]
        )

    def _init_model(self, logger, llm, auxiliary, config):
        """Build the PPO agent and restore weights when a saved file exists.

        Weights are loaded from ``<output_dir>/model/<model_name>.pt`` only if
        ``model_name`` is configured and that file exists. ``save_model``
        writes ``<model_name>_<index>.pt``, so resuming requires
        ``model_name`` to name such a stem (e.g. ``"run_final"``).

        Raises:
            ValueError: if ``model_name`` is set but ``output_dir`` is not.
        """
        self.logger.info("RL manager start to init model")
        agent = PPO(
            logger=logger,
            llm=llm,
            auxiliary=auxiliary,
            state_dim=STATE_DIM,
            action_dim=self.action_space.get_actions_count(),
            actions=self.action_space.get_all_actions(),
            config=config,
        )
        filename = self.config.get("model_name", None)
        if not filename:
            return agent

        model_filepath = self._get_model_filepath(filename)
        if os.path.exists(model_filepath):
            agent.load_model(model_filepath)
        return agent

    def _direct_ask(self, query, target):
        """Send ``query`` to the LLM unchanged, score it, and record the turn.

        Returns:
            ``(response, reward_info)`` where ``reward_info`` comes from
            ``agent.compute_reward`` with step count 0.
        """
        system_prompt = "As an all-knowing expert, answer user's instruction."
        user_prompt = f"User instruction: {query}"
        messages = [
            Message(ROLE_SYSTEM, [MessageContent(TYPE_SETTING, system_prompt)]),
            Message(ROLE_USER, [MessageContent(TYPE_CONTENT, user_prompt)]),
        ]
        response = self.llm.generate(messages)
        self.logger.info(f"Directly response: {response}")
        reward_info = self.agent.compute_reward(
            query, target, response, query, "", 0
        )
        record_info = {
            "original_query": query,
            "step_count": 0,
            "cur_query": query,
            "response": response,
            "reward_info": reward_info,
            "exp_factor": -1,
            "action_mask": None,
            "action": None,
        }
        self._record_history(record_info)
        return response, reward_info

    def _init_histories(self, original_query, init_response, init_reward_info):
        """Seed the applied/interaction histories with the direct-ask turn.

        Shared by ``exe_episode`` and ``predict`` so training and inference
        start from identical state features (``step_count`` 0).

        Returns:
            ``(applied_history, interaction_history)``.
        """
        applied_history = [{
            "query": original_query,
            "response": init_response,
        }]
        interaction_history = {
            "original_query": original_query,
            "action_history": [],
            "response": init_response,
            "step_count": 0,
            "avg_response_length": len(init_response),
            "n_sensitive_words": init_reward_info["n_sensitive_words"],
        }
        return applied_history, interaction_history

    def _advance_histories(
        self,
        applied_history,
        interaction_history,
        action_type,
        new_query,
        llm_response,
        n_sensitive_words,
        step_count,
    ):
        """Fold one applied action into both histories, in place.

        ``step_count`` is the count *after* this action, so the running
        average covers ``step_count + 1`` responses: the direct-ask response
        plus one per applied action.
        """
        prev_avg = interaction_history.get("avg_response_length", 0)
        interaction_history["avg_response_length"] = (
            prev_avg * step_count + len(llm_response)
        ) / (step_count + 1)
        interaction_history["response"] = llm_response

        if action_type == ActionType.FALLBACK:
            # Fallback undoes the latest transformation; the direct-ask entry
            # at index 0 is never removed.
            if len(applied_history) > 1:
                applied_history.pop()
            else:
                self.logger.warning("Cannot fallback with insufficient history")
        else:
            applied_history.append({
                "query": new_query,
                "response": llm_response,
            })

        interaction_history["action_history"].append({
            "action": action_type,
            "query": new_query,
            "response": llm_response,
        })
        interaction_history["step_count"] = step_count
        interaction_history["n_sensitive_words"] = n_sensitive_words

    def exe_episode(self, original_query: str, target, category: str):
        """Roll out one training episode and push its trajectory to the buffer.

        The query is first asked directly. If that already counts as success,
        the episode ends with no trajectory. Otherwise actions are sampled from
        the policy and applied until one of these happens:

        - the reward reports success;
        - the action mask is empty;
        - ``max_steps`` actions have been applied;
        - ``max_failed_attempts`` steps have failed or raised.

        Returns:
            dict with ``success``, ``done``, ``cur_query`` (latest transformed
            query) and ``step_count`` (number of applied actions).
        """
        cur_query = original_query
        step_count = 0
        done = False
        success = False
        failed_attempts = 0
        trajectory = []

        self.logger.info(f"Start to direct ask with [{original_query}]")
        init_response, init_reward_info = self._direct_ask(original_query, target)

        if init_reward_info["success"]:
            return {
                "success": init_reward_info["success"],
                "done": True,
                "cur_query": cur_query,
                "step_count": step_count,
            }

        applied_history, interaction_history = self._init_histories(
            original_query, init_response, init_reward_info
        )

        while not done and step_count < self.max_steps:
            if failed_attempts >= self.max_failed_attempts:
                self.logger.warning(
                    f"Too many failed steps ({failed_attempts}), ending episode"
                )
                break
            try:
                self.logger.info("Start to compute current state")
                state = self.agent.compute_state(
                    original_query, cur_query, interaction_history
                )

                self.logger.info("Start to compute action mask")
                action_context = {
                    "applied_history": applied_history,
                    "category": category,
                }
                action_mask = self.agent.compute_action_mask(cur_query, action_context)
                self.logger.info(f"Action mask result {action_mask}")

                if np.sum(action_mask) == 0:
                    self.logger.info("No available actions, ending episode")
                    done = True
                    break

                self.logger.info("Start to select action")
                action_idx, log_prob = self.agent.select_action(
                    state, action_mask, training=True
                )
                action_type = get_action_type_by_index(action_idx)
                self.logger.info(f"Select action: {action_type}")

                action_result = self.action_space.apply_action(
                    action_type, cur_query, action_context
                )
                if not action_result["success"]:
                    self.logger.warning("Action not apply for some reason")
                    failed_attempts += 1
                    continue

                new_query = action_result["transformed_text"]
                llm_response = action_result["response"]
                last_response = applied_history[-1].get("response", "")

                reward_info = self.agent.compute_reward(
                    new_query, target, llm_response, original_query, last_response, step_count
                )
                success = reward_info["success"]
                done = success
                reward_components: Dict = reward_info["reward_components"]

                # Critic value for the pre-action state, used as the PPO baseline.
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0)
                    reward_components_tensor = torch.FloatTensor(
                        list(reward_components.values())
                    ).unsqueeze(0)
                    value = self.agent.ac_network.forward_critic(
                        state_tensor, reward_components_tensor
                    ).item()
                exp_reward = self.agent.compute_exp_reward(reward_components)

                trajectory.append(
                    StepExp(
                        state=state,
                        action=action_idx,
                        log_prob=log_prob,
                        value=value,
                        exp_reward=exp_reward,
                        done=done,
                        mask=action_mask,
                        reward_components=reward_components,
                    )
                )

                step_count += 1
                cur_query = new_query
                self._advance_histories(
                    applied_history,
                    interaction_history,
                    action_type,
                    new_query,
                    llm_response,
                    reward_info["n_sensitive_words"],
                    step_count,
                )

                record_info = {
                    "original_query": original_query,
                    "step_count": step_count,
                    "cur_query": cur_query,
                    "response": llm_response,
                    "reward_info": reward_info,
                    "exp_factor": self.agent.exp_factor,
                    "action_mask": action_mask.tolist(),
                    "action": action_type.value,
                }
                if success:
                    record_info["reduction"] = self._content_reduction(llm_response)
                self._record_history(record_info)

            except Exception as e:
                self.logger.log_exception(e)
                failed_attempts += 1

        self.agent.buffer.add_trajectory(trajectory)
        return {
            "success": success,
            "done": done,
            "cur_query": cur_query,
            "step_count": step_count,
        }

    def predict(self, original_query: str, target: str, category: str):
        """Run the policy on one query without collecting experience.

        ``step_count`` starts at 0 after the direct ask, exactly as in
        ``exe_episode``, so state and reward inputs match what the policy was
        trained on.

        Returns:
            list of per-step dicts with keys ``success``, ``original_query``,
            ``step_count``, ``cur_query``, ``response`` and ``action``.
            If the direct ask already succeeds, the list holds a single entry
            with ``step_count`` 0 and ``action`` None.
        """
        cur_query = original_query
        step_count = 0
        done = False
        failed_attempts = 0
        result = []

        init_response, init_reward_info = self._direct_ask(original_query, target)

        if init_reward_info["success"]:
            return [
                {
                    "success": True,
                    "original_query": original_query,
                    "step_count": step_count,
                    "cur_query": cur_query,
                    "response": init_response,
                    "action": None,
                }
            ]

        applied_history, interaction_history = self._init_histories(
            original_query, init_response, init_reward_info
        )

        while not done and step_count < self.max_steps:
            if failed_attempts >= self.max_failed_attempts:
                self.logger.warning(
                    f"Too many failed steps ({failed_attempts}), ending episode"
                )
                break
            try:
                state = self.agent.compute_state(
                    original_query, cur_query, interaction_history
                )
                action_context = {
                    "applied_history": applied_history,
                    "category": category,
                }
                action_mask = self.agent.compute_action_mask(cur_query, action_context)

                if np.sum(action_mask) == 0:
                    self.logger.info("No available actions, ending episode")
                    done = True
                    break

                # NOTE(review): training=True is kept from the original; confirm
                # whether inference should use training=False instead.
                action_idx, _ = self.agent.select_action(
                    state, action_mask, training=True
                )
                action_type = get_action_type_by_index(action_idx)
                action_result = self.action_space.apply_action(
                    action_type, cur_query, action_context
                )

                if not action_result["success"]:
                    self.logger.warning("Action not apply for some reason")
                    failed_attempts += 1
                    continue

                new_query = action_result["transformed_text"]
                llm_response = action_result["response"]
                last_response = applied_history[-1].get("response", "")

                reward_info = self.agent.compute_reward(
                    new_query, target, llm_response, original_query, last_response, step_count
                )
                success = reward_info["success"]
                done = success

                step_count += 1
                cur_query = new_query

                result.append(
                    {
                        "success": success,
                        "original_query": original_query,
                        "step_count": step_count,
                        "cur_query": cur_query,
                        "response": llm_response,
                        "action": action_type,
                    }
                )

                self._advance_histories(
                    applied_history,
                    interaction_history,
                    action_type,
                    new_query,
                    llm_response,
                    reward_info["n_sensitive_words"],
                    step_count,
                )

            except Exception as e:
                self.logger.log_exception(e)
                failed_attempts += 1

        return result

    def train(self, queries: List[Dict], total_episodes: int = 1000):
        """Train the agent for ``total_episodes`` successful episodes.

        Queries are drawn with weight ``1 / (times_sampled + 1)``, so rarely
        used queries are favoured. Each query dict must have ``query`` and
        ``category``; ``target`` is optional.

        Updates and saves:

        - A PPO update runs whenever the buffer holds at least ``batch_size``
          trajectories.
        - The model is saved every ``save_interval`` updates (0 disables
          periodic saves) and once more as ``final`` at the end.

        Training stops early after ``max_consecutive_failures`` episodes raise
        in a row (falsy disables the limit).

        Raises:
            ValueError: if ``queries`` is empty.
        """
        if not queries:
            raise ValueError("No training queries provided")

        self.logger.info("RL manager start to train")
        episode_count = 0
        update_count = 0
        consecutive_failures = 0
        max_consecutive_failures = self.config.get("max_consecutive_failures", 50)
        save_interval = self.config.get("save_interval", 10)
        sample_counts = [0] * len(queries)
        failed_samples = []

        progress = tqdm(total=total_episodes, desc="Training Episodes")
        try:
            while episode_count < total_episodes:
                sample_weights = [1.0 / (count + 1) for count in sample_counts]
                chosen_idx = random.choices(range(len(queries)), weights=sample_weights)[0]

                query_info = queries[chosen_idx]
                query = query_info['query']
                target = query_info.get('target', 'None')
                category = query_info['category']

                try:
                    self.logger.info(f"RL manager start to execute query {query}")
                    exe_stats = self.exe_episode(query, target, category)
                    sample_counts[chosen_idx] += 1
                    episode_count += 1
                    consecutive_failures = 0
                    progress.update(1)
                    self.logger.info(f"Episode {episode_count}/{total_episodes} - {exe_stats}")

                    if self.agent.buffer.get_size() >= self.batch_size:
                        self.logger.info("RL manager start to update RL Network")
                        update_stats = self.agent.update(self.batch_size, self.update_epochs)
                        update_count += 1

                        self.agent.update_exp_factor(self._calculate_exp_factor(update_count))

                        self.logger.info(f"Update #{update_count} - Stats: {update_stats}")

                        # A falsy interval disables periodic saves instead of
                        # raising ZeroDivisionError.
                        if save_interval and update_count % save_interval == 0:
                            self.save_model(update_count)

                except Exception as e:
                    failed_samples.append({
                        'idx': chosen_idx,
                        'query': query,
                        'category': category,
                        'error': str(e)
                    })
                    self.logger.log_exception(e)
                    consecutive_failures += 1
                    # Without this bound a systematic failure (e.g. missing
                    # cache_dir) would loop forever, since episode_count
                    # only advances on success.
                    if max_consecutive_failures and consecutive_failures >= max_consecutive_failures:
                        self.logger.warning(
                            f"Stopping training after {consecutive_failures} consecutive failed episodes"
                        )
                        break
        finally:
            progress.close()

        self.save_model("final")

        if failed_samples:
            self.logger.warning(f"Total failed samples: {len(failed_samples)}")
            self.logger.warning("Failed samples details:")
            for sample in failed_samples:
                self.logger.warning(f"Index: {sample['idx']}, Query: {sample['query']}, Error: {sample['error']}")

    def _calculate_exp_factor(self, update_count):
        """Return ``max(start * exp(-decay * update_count), end)``.

        Correctly spelled ``exp_start_factor`` / ``exp_end_factor`` keys take
        precedence. The legacy ``exp_start_factory`` / ``exp_end_factory``
        spellings are still honoured, so existing configs keep working.
        """
        cfg = self.config
        start = cfg.get("exp_start_factor", cfg.get("exp_start_factory", 1.0))
        end = cfg.get("exp_end_factor", cfg.get("exp_end_factory", 0.0))
        decay = cfg.get("exp_decay", 0.99)

        factor = start * np.exp(-decay * update_count)
        return max(factor, end)

    def _get_model_filepath(self, filename):
        """Return ``<output_dir>/model/<filename>.pt``, creating the directory.

        Raises:
            ValueError: if ``output_dir`` is not configured.
        """
        output_dir = self.config.get("output_dir", None)
        if not output_dir:
            raise ValueError("Unknown output dir")

        model_dir = os.path.join(output_dir, "model")
        os.makedirs(model_dir, exist_ok=True)

        return os.path.join(model_dir, f"{filename}.pt")

    def save_model(self, index):
        """Save the agent as ``<model_name>_<index>.pt``.

        If ``model_name`` is not configured, the current local timestamp is
        used as the name instead.
        """
        model_name = self.config.get("model_name", None)
        if not model_name:
            model_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{model_name}_{index}"

        model_filepath = self._get_model_filepath(filename)
        self.agent.save_model(model_filepath)

    def _content_reduction(self, llm_response):
        """Ask the auxiliary API model to reduce ``llm_response``; return its output."""
        self.logger.info("Generate the response reduction")
        prompt = AgentPrompt.content_reduction(llm_response)
        messages = [
            Message(ROLE_USER, [MessageContent(TYPE_CONTENT, prompt)])
        ]
        result = self.auxiliary.get_api_generate_model().generate(messages)
        self.logger.info(f"Response reduction: {result}")
        return result

    def _record_history(self, record_info):
        """Append ``record_info`` as one JSON line to ``<cache_dir>/record_history.jsonl``.

        Raises:
            ValueError: if ``cache_dir`` is not configured.
        """
        cache_dir = self.config.get("cache_dir", None)
        if not cache_dir:
            raise ValueError("Unknown cache dir")

        os.makedirs(cache_dir, exist_ok=True)

        # Serialize before opening so a non-serializable record fails cleanly.
        line = json.dumps(record_info, ensure_ascii=False) + "\n"
        filepath = os.path.join(cache_dir, "record_history.jsonl")
        with open(filepath, "a", encoding="utf-8") as f:
            f.write(line)
